{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch resnet18 Relay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import models\n",
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "# 创建 PyTorch 模型\n",
    "model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()\n",
    "shape = [1, 3, 224, 224]\n",
    "input_infos = [(\"data\", shape)]\n",
    "# 变换为 Relay 模型\n",
    "with torch.no_grad():\n",
    "    torch_model = torch.jit.trace(model, torch.randn(shape)).eval()\n",
    "    mod, params = relay.frontend.from_pytorch(torch_model, input_infos)\n",
    "# 绑定参数并优化模型\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 1000), float32] {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */;\n",
      "  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::max_pool2d_0:0:0 */;\n",
      "  %4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__1:0:0 */;\n",
      "  %7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %9 = add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::add__0:0:0 */;\n",
      "  %10 = nn.relu(%9) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__2:0:0 */;\n",
      "  %11 = nn.conv2d(%10, meta[relay.Constant][6] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %12 = add(%11, meta[relay.Constant][7] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %13 = nn.relu(%12) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__3:0:0 */;\n",
      "  %14 = nn.conv2d(%13, meta[relay.Constant][8] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %15 = add(%14, meta[relay.Constant][9] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %16 = add(%15, %10) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::add__1:0:0 */;\n",
      "  %17 = nn.relu(%16) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__4:0:0 */;\n",
      "  %18 = nn.conv2d(%17, meta[relay.Constant][10] /* ty=Tensor[(128, 64, 3, 3), float32] */, strides=[2, 2], padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %19 = add(%18, meta[relay.Constant][11] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %20 = nn.relu(%19) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::relu__5:0:0 */;\n",
      "  %21 = nn.conv2d(%20, meta[relay.Constant][12] /* ty=Tensor[(128, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %22 = nn.conv2d(%17, meta[relay.Constant][14] /* ty=Tensor[(128, 64, 1, 1), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %23 = add(%21, meta[relay.Constant][13] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %24 = add(%22, meta[relay.Constant][15] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %25 = add(%23, %24) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::add__2:0:0 */;\n",
      "  %26 = nn.relu(%25) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::relu__6:0:0 */;\n",
      "  %27 = nn.conv2d(%26, meta[relay.Constant][16] /* ty=Tensor[(128, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %28 = add(%27, meta[relay.Constant][17] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %29 = nn.relu(%28) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::relu__7:0:0 */;\n",
      "  %30 = nn.conv2d(%29, meta[relay.Constant][18] /* ty=Tensor[(128, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %31 = add(%30, meta[relay.Constant][19] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;\n",
      "  %32 = add(%31, %26) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::add__3:0:0 */;\n",
      "  %33 = nn.relu(%32) /* ty=Tensor[(1, 128, 28, 28), float32] span=aten::relu__8:0:0 */;\n",
      "  %34 = nn.conv2d(%33, meta[relay.Constant][20] /* ty=Tensor[(256, 128, 3, 3), float32] */, strides=[2, 2], padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %35 = add(%34, meta[relay.Constant][21] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %36 = nn.relu(%35) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::relu__9:0:0 */;\n",
      "  %37 = nn.conv2d(%36, meta[relay.Constant][22] /* ty=Tensor[(256, 256, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %38 = nn.conv2d(%33, meta[relay.Constant][24] /* ty=Tensor[(256, 128, 1, 1), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %39 = add(%37, meta[relay.Constant][23] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %40 = add(%38, meta[relay.Constant][25] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %41 = add(%39, %40) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::add__4:0:0 */;\n",
      "  %42 = nn.relu(%41) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::relu__10:0:0 */;\n",
      "  %43 = nn.conv2d(%42, meta[relay.Constant][26] /* ty=Tensor[(256, 256, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %44 = add(%43, meta[relay.Constant][27] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %45 = nn.relu(%44) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::relu__11:0:0 */;\n",
      "  %46 = nn.conv2d(%45, meta[relay.Constant][28] /* ty=Tensor[(256, 256, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %47 = add(%46, meta[relay.Constant][29] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;\n",
      "  %48 = add(%47, %42) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::add__5:0:0 */;\n",
      "  %49 = nn.relu(%48) /* ty=Tensor[(1, 256, 14, 14), float32] span=aten::relu__12:0:0 */;\n",
      "  %50 = nn.conv2d(%49, meta[relay.Constant][30] /* ty=Tensor[(512, 256, 3, 3), float32] */, strides=[2, 2], padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %51 = add(%50, meta[relay.Constant][31] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %52 = nn.relu(%51) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::relu__13:0:0 */;\n",
      "  %53 = nn.conv2d(%52, meta[relay.Constant][32] /* ty=Tensor[(512, 512, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %54 = nn.conv2d(%49, meta[relay.Constant][34] /* ty=Tensor[(512, 256, 1, 1), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %55 = add(%53, meta[relay.Constant][33] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %56 = add(%54, meta[relay.Constant][35] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %57 = add(%55, %56) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::add__6:0:0 */;\n",
      "  %58 = nn.relu(%57) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::relu__14:0:0 */;\n",
      "  %59 = nn.conv2d(%58, meta[relay.Constant][36] /* ty=Tensor[(512, 512, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %60 = add(%59, meta[relay.Constant][37] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %61 = nn.relu(%60) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::relu__15:0:0 */;\n",
      "  %62 = nn.conv2d(%61, meta[relay.Constant][38] /* ty=Tensor[(512, 512, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %63 = add(%62, meta[relay.Constant][39] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;\n",
      "  %64 = add(%63, %58) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::add__7:0:0 */;\n",
      "  %65 = nn.relu(%64) /* ty=Tensor[(1, 512, 7, 7), float32] span=aten::relu__16:0:0 */;\n",
      "  %66 = nn.adaptive_avg_pool2d(%65, output_size=[1, 1]) /* ty=Tensor[(1, 512, 1, 1), float32] span=aten::adaptive_avg_pool2d_0:0:0 */;\n",
      "  %67 = reshape(%66, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 512, 1, 1), float32] span=aten::flatten_0:0:0 */;\n",
      "  %68 = squeeze(%67, axis=[2, 3]) /* ty=Tensor[(1, 512), float32] span=aten::flatten_0:0:0 */;\n",
      "  %69 = nn.dense(%68, meta[relay.Constant][40] /* ty=Tensor[(1000, 512), float32] */, units=None) /* ty=Tensor[(1, 1000), float32] span=aten::linear_0:0:0 */;\n",
      "  add(%69, meta[relay.Constant][41] /* ty=Tensor[(1000), float32] */) /* ty=Tensor[(1, 1000), float32] */\n",
      "} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mod[\"main\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
